# ------------------------------------
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
# ------------------------------------
import os
import sys
from typing import Dict

from azure.core.exceptions import ClientAuthenticationError
from azure.core.pipeline.transport import HttpRequest
from azure.core.pipeline.policies import HTTPPolicy
from azure.core.pipeline import PipelineRequest, PipelineResponse

from .._internal.msal_managed_identity_client import MsalManagedIdentityClient


class AzureArcCredential(MsalManagedIdentityClient):
    """Managed identity credential for Azure Arc hosts."""

    def get_unavailable_message(self, desc: str = "") -> str:
        """Build the message reported when Azure Arc managed identity is unavailable.

        :param str desc: Optional extra detail appended after the fixed message.
        :return: The fixed message, a space, then ``desc``.
        :rtype: str
        """
        base_message = "Azure Arc managed identity configuration not found in environment."
        return "{} {}".format(base_message, desc)


def _get_request(url: str, scope: str, identity_config: Dict) -> HttpRequest:
    """Build the GET token request for the Azure Arc endpoint.

    :param str url: The endpoint URL to request a token from.
    :param str scope: The resource the token is requested for, added as the ``resource`` query parameter.
    :param identity_config: User-assigned identity settings. Azure Arc supports only the system-assigned
        identity, so this must be empty or None.
    :type identity_config: dict or None
    :return: The prepared request, with ``api-version`` and ``resource`` query parameters.
    :rtype: ~azure.core.pipeline.transport.HttpRequest
    :raises ClientAuthenticationError: If ``identity_config`` is non-empty (a user-assigned identity was requested).
    """
    if identity_config:
        raise ClientAuthenticationError(
            message="User assigned managed identities are not supported by Azure Arc. To authenticate with the system "
            "assigned identity omit the client id when constructing the credential, and if authenticating with "
            "DefaultAzureCredential ensure the AZURE_CLIENT_ID environment variable is not set."
        )

    request = HttpRequest("GET", url)
    # identity_config is known to be empty here, so it contributes no parameters. It is deliberately not
    # unpacked into the dict: ``**None`` would raise TypeError when the caller passes None.
    request.format_parameters({"api-version": "2020-06-01", "resource": scope})
    return request


def _get_secret_key(response: PipelineResponse) -> str:
    """Read the secret key named by an Azure Arc authentication challenge.

    The challenge's ``WWW-Authenticate`` header is expected to look like ``Basic realm=<file path>``. The file
    path is checked with ``_validate_key_file`` before the file is opened, and the file's contents are returned.

    :param response: The challenge (401) response to the unauthenticated token request.
    :type response: ~azure.core.pipeline.PipelineResponse
    :return: The contents of the key file.
    :rtype: str
    :raises ClientAuthenticationError: If the header is missing or malformed, the path fails validation,
        or the file cannot be opened or read.
    """
    # expecting header containing path to secret key file
    header = response.http_response.headers.get("WWW-Authenticate")
    if not header:
        raise ClientAuthenticationError(message="Did not receive a value from WWW-Authenticate header")

    # expecting header with structure like 'Basic realm=<file path>'; split on the first "=" only so a
    # path that itself contains "=" is not truncated
    _, separator, key_file = header.partition("=")
    if not separator:
        raise ClientAuthenticationError(
            message="Did not receive a correct value from WWW-Authenticate header: {}".format(header)
        )

    try:
        _validate_key_file(key_file)
    except ValueError as ex:
        raise ClientAuthenticationError(message="The key file path is invalid: {}".format(ex)) from ex

    # The user is expected to have obtained read permission before this is called. A missing permission
    # surfaces from open() rather than read(), so the open is inside the try as well. ValueError covers
    # UnicodeDecodeError.
    try:
        with open(key_file, "r", encoding="utf-8") as file:
            return file.read()
    except (OSError, ValueError) as error:
        raise ClientAuthenticationError(
            message="Could not read file {} contents: {}".format(key_file, error)
        ) from error


def _get_key_file_path() -> str:
    """Returns the expected path for the Azure Arc MSI key file based on the current platform.

    Only Linux and Windows are supported.

    :return: The expected path.
    :rtype: str
    :raises ValueError: If the current platform is not supported.
    """
    if sys.platform.startswith("linux"):
        return "/var/opt/azcmagent/tokens"
    if sys.platform.startswith("win"):
        program_data_path = os.environ.get("PROGRAMDATA")
        if not program_data_path:
            raise ValueError("PROGRAMDATA environment variable is not set or is empty.")
        return os.path.join(f"{program_data_path}", "AzureConnectedMachineAgent", "Tokens")
    raise ValueError(f"Azure Arc MSI is not supported on this platform {sys.platform}")


def _validate_key_file(file_path: str) -> None:
    """Validates that a given Azure Arc MSI file path is valid for use.

    A valid file will:
        1. Be in the expected path for the current platform.
        2. Have a `.key` extension.
        3. Be at most 4096 bytes in size.

    :param str file_path: The path to the key file.
    :raises ClientAuthenticationError: If the file path is invalid.
    """
    if not file_path:
        raise ValueError("The file path must not be empty.")

    if not os.path.exists(file_path):
        raise ValueError(f"The file path does not exist: {file_path}")

    expected_directory = _get_key_file_path()
    if not os.path.dirname(file_path) == expected_directory:
        raise ValueError(f"Unexpected file path from HIMDS service: {file_path}")

    if not file_path.endswith(".key"):
        raise ValueError("The file path must have a '.key' extension.")

    if os.path.getsize(file_path) > 4096:
        raise ValueError("The file size must be less than or equal to 4096 bytes.")


class ArcChallengeAuthPolicy(HTTPPolicy):
    """Policy for handling Azure Arc's challenge authentication.

    Every request carries a ``Metadata: true`` header. When the response status is 401, the secret key named
    in the challenge is read, and the request is sent once more with a Basic ``Authorization`` header.
    """

    def send(self, request: PipelineRequest) -> PipelineResponse:
        """Send the request, answering a single 401 challenge if one is returned.

        :param request: The outgoing pipeline request.
        :type request: ~azure.core.pipeline.PipelineRequest
        :return: The first response when it is not a 401, otherwise the response to the retried request.
        :rtype: ~azure.core.pipeline.PipelineResponse
        """
        request.http_request.headers["Metadata"] = "true"
        first_response = self.next.send(request)

        # Anything other than a challenge is passed straight back to the caller.
        if first_response.http_response.status_code != 401:
            return first_response

        key = _get_secret_key(first_response)
        request.http_request.headers["Authorization"] = f"Basic {key}"
        return self.next.send(request)
